# Licensed under a 3-clause BSD style license - see LICENSE.rst

from types import FunctionType
from contextlib import contextmanager
from functools import wraps

from astropy.table import QTable

# Names exported by a star-import of this module.
__all__ = ['BaseTimeSeries', 'autocheck_required_columns']

# Names of the methods that ``autocheck_required_columns`` wraps so that the
# required-column check is re-run after each of them returns.
COLUMN_RELATED_METHODS = [
    'add_column',
    'add_columns',
    'keep_columns',
    'remove_column',
    'remove_columns',
    'rename_column',
]


def autocheck_required_columns(cls):
    """
    Class decorator that re-validates the required columns after every
    column-modifying method.

    Each method named in ``COLUMN_RELATED_METHODS`` is replaced on ``cls`` by
    a wrapper that calls the original method, then calls
    ``self._check_required_columns()``, and finally returns whatever the
    original method returned.

    Parameters
    ----------
    cls : type
        The class to decorate. It must expose every name listed in
        ``COLUMN_RELATED_METHODS`` as a plain Python function.

    Returns
    -------
    cls : type
        The same class object, modified in place.

    Raises
    ------
    ValueError
        If one of the listed names is missing on ``cls`` or is not a plain
        function.
    """

    def _with_column_check(func):
        # ``wraps`` keeps the name/docstring of the method being wrapped.
        @wraps(func)
        def checked(self, *args, **kwargs):
            outcome = func(self, *args, **kwargs)
            # Validate only after the method completed without raising.
            self._check_required_columns()
            return outcome

        return checked

    for name in COLUMN_RELATED_METHODS:
        # A single lookup: a missing attribute yields None, which fails the
        # isinstance test exactly like any other non-function attribute.
        target = getattr(cls, name, None)
        if not isinstance(target, FunctionType):
            raise ValueError(f"{name} is not a valid method")
        setattr(cls, name, _with_column_check(target))

    return cls


class BaseTimeSeries(QTable):
    """
    ``QTable`` subclass that can enforce a fixed set of leading columns.

    Subclasses set ``_required_columns`` to a list of column names; the
    table is then considered valid only if its first columns are exactly
    those names, in that order. The check itself lives in
    ``_check_required_columns`` and is not triggered automatically by this
    class alone.
    """

    # List of column names that must be the first columns of the table, or
    # None to disable the requirement altogether.
    _required_columns = None

    # When False, _check_required_columns returns immediately without
    # validating anything (see _delay_required_column_checks).
    _required_columns_enabled = True

    # If _required_columns_relax is True, we don't require the columns to be
    # present but we do require them to be the correct ones IF present. Note
    # that this is a temporary state - as soon as the required columns
    # are all present, we toggle this to False
    _required_columns_relax = False

    def _check_required_columns(self):
        """
        Verify that the table starts with the required columns.

        Does nothing when ``_required_columns_enabled`` is False or when
        ``_required_columns`` is None. In relaxed mode only as many required
        columns as the table currently has columns are compared; once every
        required column is present, relaxed mode is switched off on this
        instance.

        Raises
        ------
        ValueError
            If the table has no columns (outside relaxed mode), or if its
            leading columns differ from the required ones.
        """

        def as_scalar_or_list_str(obj):
            # Render a single name as 'name' and anything longer via str().
            if not hasattr(obj, "__len__"):
                return f"'{obj}'"
            elif len(obj) == 1:
                return f"'{obj[0]}'"
            else:
                return str(obj)

        if not self._required_columns_enabled:
            return

        if self._required_columns is not None:

            if self._required_columns_relax:
                # Only compare against as many required columns as the table
                # currently holds.
                required_columns = self._required_columns[:len(self.colnames)]
            else:
                required_columns = self._required_columns

            plural = 's' if len(required_columns) > 1 else ''

            if not self._required_columns_relax and len(self.colnames) == 0:

                raise ValueError("{} object is invalid - expected '{}' "
                                 "as the first column{} but time series has no columns"
                                 .format(self.__class__.__name__, required_columns[0], plural))

            elif self.colnames[:len(required_columns)] != required_columns:

                raise ValueError("{} object is invalid - expected {} "
                                 "as the first column{} but found {}"
                                 .format(self.__class__.__name__,
                                         as_scalar_or_list_str(required_columns),
                                         plural,
                                         as_scalar_or_list_str(
                                             self.colnames[:len(required_columns)])))

            # All required columns are now in place: leave relaxed mode so
            # later checks are strict.
            if (self._required_columns_relax
                    and self._required_columns == self.colnames[:len(self._required_columns)]):
                self._required_columns_relax = False

    @contextmanager
    def _delay_required_column_checks(self):
        """
        Context manager that suspends the required-column check.

        While the ``with`` body runs, ``_check_required_columns`` is a no-op.
        On normal exit the check is re-enabled and run once. If the body
        raises, the check is still re-enabled (so the object does not stay
        permanently unchecked), but it is not run, so that the original
        exception propagates unmasked.
        """
        self._required_columns_enabled = False
        try:
            yield
        finally:
            # Always restore the flag, even when the with-body raised.
            self._required_columns_enabled = True
        # Reached only when the with-body completed without an exception.
        self._check_required_columns()
